import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '../'))

import torch
import json
import cv2
import random
import time
import argparse
import numpy as np
import warnings
import torchvision
import glob
from PIL import Image
from collections import OrderedDict, defaultdict
from torch.utils.data import ConcatDataset

from image_synthesis.utils.io import load_yaml_config
from image_synthesis.utils.misc import instantiate_from_config
from image_synthesis.utils.cal_metrics import get_PSNR, get_mse_loss, get_l1_loss, get_SSIM
from image_synthesis.modeling.build import build_model
from image_synthesis.utils.misc import format_seconds
from image_synthesis.distributed.launch import launch
from image_synthesis.distributed.distributed import reduce_dict, synchronize, all_gather
from image_synthesis.utils.misc import get_model_parameters_info, get_model_buffer

def image_post_process(image):
    """Clip image data to the ``[0, 255]`` range and cast it to ``uint8``.

    Args:
        image: a ``torch.Tensor``, a ``np.ndarray``, or a ``dict`` / ``list``
            whose entries are array-like.

    Returns:
        For a tensor/array input, a new ``np.uint8`` array. For a ``dict`` or
        ``list`` input, the very same container, with each entry replaced in
        place by its converted ``np.uint8`` array.

    Raises:
        ValueError: if ``image`` is none of the supported types.
    """
    def convert(t):
        if torch.is_tensor(t):
            # np.asarray() cannot read CUDA tensors or tensors that track
            # gradients, so detach and copy to host memory first.
            t = t.detach().cpu().numpy()
        return np.asarray(t).clip(0, 255).astype(np.uint8)

    if isinstance(image, (torch.Tensor, np.ndarray)):
        return convert(image)
    if isinstance(image, dict):
        for k in image:
            image[k] = convert(image[k])
        return image
    if isinstance(image, list):
        for k in range(len(image)):
            image[k] = convert(image[k])
        return image
    raise ValueError(
        'Unsupported image type {}; expected a tensor, ndarray, dict or list'.format(type(image)))

def save_list_results(images, save_root, batch_idx, local_rank=0):
    '''
    Save every sample of a batch as one PNG in which all variants of that
    sample are placed side by side, left to right.

    images: list of tensors, each tensor is with [N, C, H, W]. N, H and W are
        read from the first tensor and used for the whole list.
    save_root: output directory, created if missing.
    batch_idx: index of the batch; the file index is batch_idx*N + i.
    local_rank: rank written into the file name.
    '''
    os.makedirs(save_root, exist_ok=True)
    N, _, H, W = images[0].shape
    # detach()/cpu() so that CUDA tensors and tensors tracking gradients can
    # be converted to numpy as well.
    images = [im.detach().cpu().permute(0, 2, 3, 1).numpy() for im in images]
    for i in range(N):
        # Zero-pad the index: '{:6d}' pads with blanks and produced file names
        # that start with spaces.
        file_name = '{:06d}_rank{}.png'.format(batch_idx*N + i, local_rank)
        tiles = [Image.fromarray(image_post_process(im[i])) for im in images]
        if len(tiles) == 1:
            canvas = tiles[0]
        else:
            # One canvas wide enough for all tiles instead of re-allocating a
            # larger canvas for every pasted tile.
            canvas = Image.new(tiles[0].mode, (len(tiles)*W, H))
            for j, tile in enumerate(tiles):
                canvas.paste(tile, box=(j*W, 0))
        save_path = os.path.join(save_root, file_name)
        canvas.save(save_path)
        print('saved in {}'.format(save_path))

def save_image_dict(images, save_dir, batch_idx, local_rank=0, make_grid=True, ignored_keys=None, suffix=None):
    """Write every entry of a result dict below ``save_dir``.

    Each key gets its own sub-directory (``<key>`` or ``<key>_<suffix>``).
    4-D tensors with 1 or 3 channels are cast to ``uint8`` and stored as PNG;
    every other value is appended as ``str(value)`` to a ``.txt`` file.

    Args:
        images: dict mapping a name to a tensor or to any other value.
        save_dir: root output directory.
        batch_idx: batch index, used (zero padded) in the file name.
        local_rank: rank written into the file name.
        make_grid: if True, store the whole batch as one grid image;
            otherwise store one PNG per sample (3-channel tensors only).
        ignored_keys: optional container of keys to skip.
        suffix: optional suffix for the per-key directory name.
    """
    for k, v in images.items():
        if ignored_keys is not None and k in ignored_keys:
            continue
        dir_name = k if suffix is None else k + '_' + suffix
        save_dir_ = os.path.join(save_dir, dir_name)
        os.makedirs(save_dir_, exist_ok=True)
        save_path = os.path.join(save_dir_, '{:06d}_rank{}'.format(batch_idx, local_rank))
        if torch.is_tensor(v) and v.dim() == 4 and v.shape[1] in [1, 3]: # image
            im = v.to(torch.uint8) # N x C x H x W

            # save images
            if make_grid:
                im_grid = torchvision.utils.make_grid(im)
                im_grid = im_grid.permute(1, 2, 0).to('cpu').numpy()
                Image.fromarray(im_grid).save(save_path + '.png')
                print('save {} to {}'.format(k, save_path+'.png'))
            elif v.shape[1] == 3: # for this, we only save generated images
                save_path_ = None
                for i in range(im.shape[0]):
                    im_ = im[i].permute(1, 2, 0).to('cpu').numpy()
                    save_path_ = save_path + '_{}.png'.format(i)
                    Image.fromarray(im_).save(save_path_)
                # An empty batch writes nothing, so there is no path to report
                # (this used to fail with a NameError).
                if save_path_ is not None:
                    print('save {} to {}'.format(k, save_path_))
        else: # may be other values, such as scalars or strings
            # append mode: repeated calls for the same batch accumulate lines
            with open(save_path+'.txt', 'a') as f:
                f.write(str(v)+'\n')
            print('save {} to {}'.format(k, save_path+'.txt'))


def save_image_pair(images, save_root, batch_idx, local_rank=0):
    '''
    Tile several image batches into a single picture and store it as PNG.

    images: list or tuple, each element in it is a tensor, [N, C, H, W].
        Every batch is laid out by make_grid with all N samples in one row;
        the rows of the different batches are concatenated along dim 1 of the
        (3, H, W) grids, i.e. stacked from top to bottom.
    save_root: output directory, created if missing.
    batch_idx: batch index, zero padded to 6 digits in the file name.
    local_rank: rank written into the file name.
    '''
    os.makedirs(save_root, exist_ok=True)

    for row_idx, batch in enumerate(images):
        row = torchvision.utils.make_grid(batch, nrow=batch.shape[0]) # 3, H, W
        stacked = row if row_idx == 0 else torch.cat((stacked, row), dim=1)

    # channel-last uint8 array, as expected by PIL
    canvas = stacked.permute(1, 2, 0).to('cpu').numpy().astype(np.uint8)
    out_path = os.path.join(save_root, '{}_rank{}.png'.format(str(batch_idx).zfill(6), local_rank))
    Image.fromarray(canvas).save(out_path)
    print('saved in {}'.format(out_path))

def get_model(model_name='2020-11-09T13-33-36_faceshq_vqgan'):
    """Build a model from its config and load the matching checkpoint, if any.

    Args:
        model_name: one of
            * a path to an existing ``.pth`` / ``.ckpt`` file; the config is
              then read from ``<ckpt dir>/../configs/config.yaml``;
            * a path to an existing ``.yaml`` file; the checkpoint is then
              looked up at ``<yaml dir>/../checkpoint/last.pth``;
            * anything else is treated as an experiment name below ``OUTPUT``.

    Returns:
        The model built by ``build_model``. Weights are loaded non-strictly
        from the ``'model'`` or ``'state_dict'`` entry of the checkpoint; when
        the checkpoint file is missing or has neither entry, the model is
        returned as built and a warning is printed.

    Raises:
        RuntimeError: if ``model_name`` is an existing file with an
            unsupported extension.
    """
    if os.path.isfile(model_name):
        if model_name.endswith(('.pth', '.ckpt')):
            model_path = model_name
            config_path = os.path.join(os.path.dirname(model_name), '..', 'configs', 'config.yaml')
        elif model_name.endswith('.yaml'):
            config_path = model_name
            model_path = os.path.join(os.path.dirname(model_name), '..', 'checkpoint', 'last.pth')
        else:
            raise RuntimeError(model_name)
    else:
        model_path = os.path.join('OUTPUT', model_name, 'checkpoint', 'last.pth')
        config_path = os.path.join('OUTPUT', model_name, 'configs', 'config.yaml')
    print(config_path)
    config = load_yaml_config(config_path)
    model = build_model(config)
    model_parameters = get_model_parameters_info(model)
    print(model_parameters)

    if os.path.exists(model_path):
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        ckpt = torch.load(model_path, map_location="cpu")
    else:
        ckpt = {}

    if 'model' in ckpt:
        missing, unexpected = model.load_state_dict(ckpt["model"], strict=False)
    elif 'state_dict' in ckpt:
        missing, unexpected = model.load_state_dict(ckpt["state_dict"], strict=False)
    else:
        missing, unexpected = [], []
        print("====> Warning! No pretrained model!")

    print('Model missing keys:\n', missing)
    print('Model unexpected keys:\n', unexpected)

    return model



def get_model_and_dataset(args=None, model_name='2020-11-09T13-33-36_faceshq_vqgan'):
    """Build the model, load its checkpoint and instantiate the dataset.

    Args:
        args: optional namespace. When given, ``args.model_path`` and
            ``args.config_path`` are set to the resolved paths, and
            ``args.ema``, ``args.ema_no_buffer`` and ``args.data_type`` are
            read.
        model_name: a checkpoint path (``.pth`` / ``.ckpt``), a config path
            (``.yaml``) or an experiment name below ``OUTPUT``.

    Returns:
        dict with the keys ``'model'``, ``'data'`` (a single dataset or a
        ``ConcatDataset``), ``'epoch'`` (taken from the checkpoint, 0 if
        absent), ``'model_name'`` and ``'parameter'`` (the result of
        ``get_model_parameters_info``).

    Raises:
        RuntimeError: if ``model_name`` is an existing file with an
            unsupported extension.
        ValueError: if ``args.ema_no_buffer`` is set but the EMA state holds
            nothing besides parameters.
    """
    if os.path.isfile(model_name):
        if model_name.endswith(('.pth', '.ckpt')):
            model_path = model_name
            config_path = os.path.join(os.path.dirname(model_name), '..', 'configs', 'config.yaml')
        elif model_name.endswith('.yaml'):
            config_path = model_name
            model_path = os.path.join(os.path.dirname(model_name), '..', 'checkpoint', 'last.pth')
        else:
            raise RuntimeError(model_name)

        if 'OUTPUT' in model_name: # pretrained model
            model_name = model_path.split(os.path.sep)[-3]
        else: # just give a config file, such as test_openai_dvae.yaml, which is no need to train, just test
            model_name = os.path.basename(config_path).replace('.yaml', '')
    else:
        model_path = os.path.join('OUTPUT', model_name, 'checkpoint', 'last.pth')
        config_path = os.path.join('OUTPUT', model_name, 'configs', 'config.yaml')

    # `args` is optional everywhere else in this function, so only record the
    # resolved paths when a namespace was actually supplied.
    if args is not None:
        args.model_path = model_path
        args.config_path = config_path

    config = load_yaml_config(config_path)
    model = build_model(config)
    model_parameters = get_model_parameters_info(model)
    print(model_parameters)
    if os.path.exists(model_path):
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        ckpt = torch.load(model_path, map_location="cpu")
    else:
        ckpt = {}
    if 'last_epoch' in ckpt:
        epoch = ckpt['last_epoch']
    elif 'epoch' in ckpt:
        epoch = ckpt['epoch']
    else:
        epoch = 0

    if 'model' in ckpt:
        missing, unexpected = model.load_state_dict(ckpt["model"], strict=False)
    elif 'state_dict' in ckpt:
        missing, unexpected = model.load_state_dict(ckpt["state_dict"], strict=False)
    else:
        missing, unexpected = [], []
        print("====> Warning! No pretrained model!")

    print('Model missing keys:\n', missing)
    print('Model unexpected keys:\n', unexpected)

    if args is not None and args.ema and 'ema' in ckpt:
        print("Evaluate EMA model")
        if hasattr(model, 'get_ema_model') and callable(model.get_ema_model):
            ema_model = model.get_ema_model()
        else:
            ema_model = model

        if args.ema_no_buffer:
            # load only the parameters from the EMA state and leave the
            # buffers of the EMA model untouched
            ema_param = OrderedDict()
            for n, p in ema_model.named_parameters():
                ema_param[n] = ckpt['ema'][n]
            missing, unexpected = ema_model.load_state_dict(ema_param, strict=False)
            skipped_buffer_name = [k for k in ckpt['ema'].keys() if k not in ema_param]
            if len(skipped_buffer_name) == 0:
                raise ValueError('No buffer found in ema model, please set args.ema_no_buffer to False')
            print('EMA model skipped buffer:\n', skipped_buffer_name)
        else:
            missing, unexpected = ema_model.load_state_dict(ckpt['ema'], strict=False)
        # the load is non-strict, so report mismatches instead of dropping them
        print('EMA model missing keys:\n', missing)
        print('EMA model unexpected keys:\n', unexpected)

    data_type = 'validation_datasets'
    if args is not None and args.data_type == 'train':
        data_type = 'train_datasets'

    datasets = [instantiate_from_config(ds_cfg) for ds_cfg in config['dataloader'][data_type]]
    if len(datasets) > 1:
        val_dataset = ConcatDataset(datasets)
    else:
        val_dataset = datasets[0]

    return {'model': model, 'data': val_dataset, 'epoch': epoch, 'model_name': model_name, 'parameter': model_parameters}


def update_ema_model_buffer(local_rank, args=None, model=None, data=None):
    """Run the dataset through the model in train mode and save the buffers
    of the EMA model.

    The buffers returned by ``get_model_buffer(ema_model)`` are written to
    ``ema_updated_buffer_<rank>_<itr>.pth`` and
    ``ema_updated_buffer_<rank>_last.pth`` next to ``args.model_path`` every
    1000 iterations and after the final batch.

    Args:
        local_rank: rank written into the file names.
        args: namespace; ``batch_size``, ``model_path``, ``distributed``
            (and ``gpu`` / ``name`` where needed) are read.
        model: optional model; loaded through ``get_model_and_dataset`` when
            either ``model`` or ``data`` is None.
        data: optional dataset, see ``model``.
    """
    if model is None or data is None:
        info = get_model_and_dataset(args=args, model_name=args.name)
        model = info['model']
        data = info['data']

    for p in model.parameters():
        p.requires_grad = False

    num_workers = 0 # max( 2, batch_size // 4)
    if args is not None and args.distributed:
        print('DDP')
        p.requires_grad = True # only set the last one parameter requires grad for DDP
        model = model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])

        sampler = torch.utils.data.distributed.DistributedSampler(data, shuffle=False)
        # DataLoader rejects `shuffle=True` together with a sampler
        # (ValueError), so ordering is left entirely to the sampler.
        dataloader = torch.utils.data.DataLoader(data,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=num_workers,
                                             pin_memory=True,
                                             sampler=sampler,
                                             drop_last=True)

        # The DDP wrapper does not expose attributes of the wrapped model, so
        # both checks have to go through `model.module`.
        if hasattr(model.module, 'get_ema_model') and callable(model.module.get_ema_model):
            model.eval()
            ema_model = model.module.get_ema_model()
            ema_model.train()
        else:
            ema_model = model.module
            model.train()
    else:
        model = model.cuda()
        dataloader = torch.utils.data.DataLoader(data, batch_size=args.batch_size, num_workers=num_workers, shuffle=True, drop_last=False)
        if hasattr(model, 'get_ema_model') and callable(model.get_ema_model):
            model.eval()
            ema_model = model.get_ema_model()
            ema_model.train()
        else:
            ema_model = model
            ema_model.train()

    # NOTE(review): this puts the whole model back into train mode and thereby
    # overrides the model.eval() calls above; kept as in the original.
    model.train()
    save_interval = 1000
    save_path = os.path.join(os.path.dirname(args.model_path), 'ema_updated_buffer_{rank}_{itr}.pth')
    # Count the batches this process really iterates over (the loader accounts
    # for the sampler and for drop_last), so the final save cannot be missed.
    num_batchs = len(dataloader)
    for itr, batch in enumerate(dataloader):
        print('{}/{}'.format(itr, num_batchs))
        for k, v in batch.items():
            if torch.is_tensor(v):
                batch[k] = v.cuda()

        model_input = {
                'batch': batch,
                'return_loss': False
                }
        # the forward pass is run for its effect on the model; its result is
        # not needed
        model(**model_input)

        if (itr+1) % save_interval == 0 or itr == num_batchs - 1:
            buffer = get_model_buffer(ema_model)
            save_ = {
                'model': buffer,
                'iteration': itr,
                'batch_size': args.batch_size,
            }
            torch.save(save_, save_path.format(rank=local_rank, itr=itr))
            torch.save(save_, save_path.format(rank=local_rank, itr='last'))
            print('saved to {}'.format(save_path.format(rank=local_rank, itr='last')))

def caculate_flops_and_params(local_rank=0, args=None):
    """Profile the model on the first sample of the dataset and print the
    formatted MACs and parameter count.

    The MACs come from ``thop.profile``; the parameter count is obtained by
    summing ``nelement()`` over ``model.parameters()``.

    Args:
        local_rank: rank shown in the printed line.
        args: namespace forwarded to ``get_model_and_dataset``; ``args.name``
            selects the model.
    """
    from thop import clever_format, profile

    loaded = get_model_and_dataset(args=args, model_name=args.name)
    net = loaded['model'].cuda()
    net.train()
    loader = torch.utils.data.DataLoader(loaded['data'], batch_size=1, num_workers=1, shuffle=False, drop_last=False)

    # only the first sample is profiled
    for sample in loader:
        macs, _ = profile(net, inputs=(sample,))

        # the count reported by thop is replaced by a direct sum over the
        # parameters of the model
        total = torch.DoubleTensor([0]).to(net.device)
        for weight in net.parameters():
            total += weight.nelement()

        macs, params = clever_format([macs, total], "%.3f")
        break
    print('rank: {},'.format(local_rank), macs, params)
    
def inference_token(local_rank=0, args=None):
    """Encode the dataset into tokens and dump the tokens of the first batches.

    used  for image based auto encoder

    Every batch is passed through ``get_tokens`` of the model. When the model
    is not wrapped in DDP, the image value range and the token dict are
    printed, and ``token['token']`` of the first 10 batches is saved as
    ``token_rank<rank>_batch<i>.pth`` below the result directory.

    Args:
        local_rank: rank written into the file names.
        args: namespace; ``name``, ``distributed``, ``gpu``, ``batch_size``,
            ``world_size``, ``save_dir``, ``data_type``, ``ema`` and
            ``ema_no_buffer`` are read.
    """
    info = get_model_and_dataset(args=args, model_name=args.name)
    model = info['model']
    data = info['data']
    epoch = info['epoch']
    model_name = info['model_name']

    num_workers = 0 # max( 2, batch_size // 4)
    if args is not None and args.distributed:
        print('DDP')
        model = model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])

        sampler = torch.utils.data.distributed.DistributedSampler(data, shuffle=False)
        dataloader = torch.utils.data.DataLoader(data,
                                             batch_size=args.batch_size,
                                             shuffle=False, #(val_sampler is None),
                                             num_workers=num_workers,
                                             pin_memory=True,
                                             sampler=sampler,
                                             drop_last=True)
    else:
        model = model.cuda()
        dataloader = torch.utils.data.DataLoader(data, batch_size=args.batch_size, num_workers=num_workers, shuffle=False, drop_last=False)

    # the methods of the model are only reachable on the unwrapped module
    is_ddp = isinstance(model, torch.nn.parallel.DistributedDataParallel)
    net = model.module if is_ddp else model

    num_batch = len(data) // (args.batch_size * args.world_size)
    print('images:', len(data)//args.world_size)

    # result directory
    save_root = os.path.join(args.save_dir, model_name+'_{}'.format(args.data_type))
    if args.ema:
        if args.ema_no_buffer:
            save_root += '_emaNoBuffer'
        else:
            save_root += '_ema'
    save_root = save_root + '_e{}'.format(epoch)

    print('results will be saved in {}'.format(save_root))
    save_count = 10
    save_token = True

    for i, data_i in enumerate(dataloader):
        print("{}/{}".format(i, num_batch))

        with torch.no_grad():
            img = data_i['image'].to(net.device)
            if not is_ddp:
                print(img.min(), img.max())
            token = net.get_tokens(img)

        # NOTE(review): printing and dumping of tokens only happens when the
        # model is not DDP-wrapped, exactly as before; the DDP path used to
        # additionally decode the tokens and throw the result away.
        if is_ddp:
            continue
        print(token)

        # save tokens
        if i < save_count and save_token:
            os.makedirs(save_root, exist_ok=True)
            token_save_path = os.path.join(save_root, 'token_rank{}_batch{}.pth'.format(local_rank, i))
            torch.save(token['token'], token_save_path)
            print('token saved in {}'.format(token_save_path))


def inference_reconstruction(local_rank=0, args=None):
    """Reconstruct the dataset with the auto encoder and report the metrics.

    used  for image based auto encoder

    Every batch is encoded with ``get_tokens`` and decoded again. MSE, L1,
    PSNR and SSIM are accumulated weighted by the batch size, reduced over
    all processes and averaged. Written below the result directory:

    * image / reconstruction pairs of the first 10 batches,
    * ``token['token']`` of the first 10 batches (non-DDP runs only),
    * ``token_freqency.json``: how often each token index occurred, most
      frequent first (counted only when the token dict has 'token_index'),
    * ``total_loss.json`` and ``model_parameters.json`` (local rank 0 only).

    Args:
        local_rank: rank written into the file names.
        args: namespace; ``name``, ``distributed``, ``gpu``, ``batch_size``,
            ``world_size``, ``save_dir``, ``data_type``, ``ema`` and
            ``ema_no_buffer`` are read.
    """
    info = get_model_and_dataset(args=args, model_name=args.name)
    model = info['model']
    data = info['data']
    epoch = info['epoch']
    model_name = info['model_name']

    num_workers = 0 # max( 2, batch_size // 4)
    if args is not None and args.distributed:
        print('DDP')
        model = model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])

        sampler = torch.utils.data.distributed.DistributedSampler(data, shuffle=False)
        dataloader = torch.utils.data.DataLoader(data,
                                             batch_size=args.batch_size,
                                             shuffle=False, #(val_sampler is None),
                                             num_workers=num_workers,
                                             pin_memory=True,
                                             sampler=sampler,
                                             drop_last=True)
    else:
        model = model.cuda()
        dataloader = torch.utils.data.DataLoader(data, batch_size=args.batch_size, num_workers=num_workers, shuffle=False, drop_last=False)

    # the methods of the model are only reachable on the unwrapped module
    is_ddp = isinstance(model, torch.nn.parallel.DistributedDataParallel)
    net = model.module if is_ddp else model

    num_batch = len(data) // (args.batch_size * args.world_size)
    print('images:', len(data)//args.world_size)
    total_loss = {"mse_loss": 0.0, "psnr": 0.0, "l1_loss": 0.0, "ssim": 0.0}
    total_batch = 0.0
    # save images
    save_root = os.path.join(args.save_dir, model_name+'_{}'.format(args.data_type))
    if args.ema:
        if args.ema_no_buffer:
            save_root += '_emaNoBuffer'
        else:
            save_root += '_ema'
    save_root = save_root + '_e{}'.format(epoch)

    print('results will be saved in {}'.format(save_root))
    save_count = 10
    save_token = True

    # one CUDA scalar per token index so the counts can go through reduce_dict
    token_freq = OrderedDict()
    for i in range(net.get_number_of_tokens()):
        token_freq[i] = torch.tensor(0.0).cuda()

    for i, data_i in enumerate(dataloader):
        print("{}/{}".format(i, num_batch))

        with torch.no_grad():
            img = data_i['image'].to(net.device)
            token = net.get_tokens(img)
            rec = net.decode(token['token'])

            # save tokens
            # NOTE(review): tokens are only dumped when the model is not
            # DDP-wrapped, exactly as before.
            if not is_ddp and i < save_count and save_token:
                os.makedirs(save_root, exist_ok=True)
                token_save_path = os.path.join(save_root, 'token_rank{}_batch{}.pth'.format(local_rank, i))
                torch.save(token['token'], token_save_path)
                print('token saved in {}'.format(token_save_path))

            mse_loss = get_mse_loss(img, rec)
            l1_loss = get_l1_loss(img, rec)
            psnr = get_PSNR(img, rec)
            ssim = get_SSIM(img, rec)

        # weight by the batch size so the final division yields per-image means
        total_loss['mse_loss'] += mse_loss * img.shape[0]
        total_loss['l1_loss'] += l1_loss * img.shape[0]
        total_loss['psnr'] += psnr * img.shape[0]
        total_loss['ssim'] += ssim * img.shape[0]

        total_batch += img.shape[0]
        if i < save_count:
            save_image_pair([img, rec], save_root, i, local_rank=local_rank)

        # get token count
        if 'token_index' in token:
            token_index_list = token['token_index'].view(-1).tolist()
            for idx in token_index_list:
                token_freq[int(idx)] += 1

    # save token frequency
    # NOTE(review): every process writes this file to the same path; confirm
    # against reduce_dict whether all ranks hold the reduced counts.
    token_freq = reduce_dict(token_freq, average=False)
    token_idx = []
    token_count = []
    for k, v in token_freq.items():
        token_idx.append(k)
        token_count.append(int(v))
    token_freq_ = OrderedDict()
    # walk the ascending argsort backwards: most frequent token first
    for pos in np.argsort(token_count)[::-1]:
        token_freq_[token_idx[pos]] = token_count[pos]
    # the directory may not exist yet when no batch has been processed
    os.makedirs(save_root, exist_ok=True)
    token_freq_path = os.path.join(save_root, 'token_freqency.json')
    with open(token_freq_path, 'w') as f:
        json.dump(token_freq_, f, indent=4)

    synchronize()
    total_loss = reduce_dict(total_loss, average=False)
    total_batch = sum(all_gather(total_batch))
    for k in total_loss:
        total_loss[k] = total_loss[k] / total_batch

    if local_rank == 0:
        for k in total_loss.keys():
            total_loss[k] = float(total_loss[k])
        loss_path = os.path.join(save_root, 'total_loss.json')
        os.makedirs(os.path.dirname(loss_path), exist_ok=True)
        with open(loss_path, 'w') as f:
            json.dump(total_loss, f, indent=4)
        print(total_loss)

        # save model parameters info
        info_path = os.path.join(save_root, 'model_parameters.json')
        with open(info_path, 'w') as f:
            json.dump(info['parameter'], f, indent=4)


def inference_generate_sample_with_condition(local_rank=0, args=None):
    """Generate samples for conditions drawn from the dataset.

    For every visited sample a directory named after its condition is created
    below the result directory. It receives ``condition.txt``, the ground
    truth image and its reconstruction (only if no such file exists yet), and
    generated images named ``rank_<rank>_<count>_fr<fr>_cr<cr>.png``.
    Generation for a condition continues until the directory holds at least
    50 generated images; existing files count, so a rerun resumes.

    Args:
        local_rank: rank written into the file names.
        args: namespace; ``name``, ``distributed``, ``gpu``, ``batch_size``
            (number of replicas generated per call), ``world_size``,
            ``save_dir``, ``data_type``, ``ema`` and ``ema_no_buffer`` are
            read.
    """
    info = get_model_and_dataset(args=args, model_name=args.name)
    model = info['model']
    data = info['data']
    epoch = info['epoch']
    model_name = info['model_name']

    if args is not None and args.distributed:
        print('DDP')
        model = model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])

        sampler = torch.utils.data.distributed.DistributedSampler(data, shuffle=True)
        dataloader = torch.utils.data.DataLoader(data,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=1,
                                             pin_memory=True,
                                             sampler=sampler,
                                             drop_last=True)
    else:
        model = model.cuda()
        dataloader = torch.utils.data.DataLoader(data, batch_size=1, num_workers=1, shuffle=True, drop_last=False)
    print('images:', len(data)//args.world_size)

    # the methods of the model are only reachable on the unwrapped module
    net = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model

    count_cond = 100 // (args.world_size if args is not None else 1)
    count_per_cond = 50
    return_att_weight = False # True #False
    filter_ratio = [0.3, 0.5, 0.5, 0.7, 0.9] #[0.3, 0.5, 50, 100]
    content_ratio = [0.0] #, 0.5] # [0.0, 0.0, 0.3, 0.7]
    chosed_fr = []
    chosed_cr = []

    # save images
    save_root = os.path.join(args.save_dir, model_name+'_{}'.format(args.data_type))
    if args.ema:
        if args.ema_no_buffer:
            save_root += '_emaNoBuffer'
        else:
            save_root += '_ema'
    save_root = save_root + '_e{}_generate'.format(epoch)

    os.makedirs(save_root, exist_ok=True)
    print('results will be saved in {}'.format(save_root))

    generated_pattern = 'rank_*_*_fr*_cr*.png'
    start_gen = time.time()
    for i, data_i in enumerate(dataloader):
        # NOTE(review): `>` lets count_cond + 1 samples through; kept as is.
        if i > count_cond:
            break

        condition_info = net.condition_info
        condition = data_i[condition_info['key']][0]

        if torch.is_tensor(condition) and condition.numel() == 1:
            str_cond = str(condition.view(-1).numpy()[0])
        else:
            str_cond = str(condition)

        save_root_ = os.path.join(save_root, str_cond)
        os.makedirs(save_root_, exist_ok=True)

        # save_condition
        with open(os.path.join(save_root_, 'condition.txt'), 'w') as fc:
            fc.write(str_cond)

        # save gt image and reconstruction
        gt_count_ = len(glob.glob(os.path.join(save_root_, 'gt_image_rank*.png')))
        if gt_count_ == 0:
            gt_im = Image.fromarray(data_i['image'][0].permute(1, 2, 0).to('cpu').numpy().astype(np.uint8))
            save_path = os.path.join(save_root_, 'gt_image_rank{}.png'.format(local_rank))
            gt_im.save(save_path)

            if gt_im.size != (256, 256) and 'image256' in data_i:
                gt_im256 = Image.fromarray(data_i['image256'][0].permute(1, 2, 0).to('cpu').numpy().astype(np.uint8))
                save_path = os.path.join(save_root_, 'gt_image256_rank{}.png'.format(local_rank))
                gt_im256.save(save_path)

        gt_rec_count_ = len(glob.glob(os.path.join(save_root_, 'gt_image_rec_rank*.png')))
        if gt_rec_count_ == 0:
            rec = net.reconstruct(input=data_i['image']) # B x C x H x W
            rec_im = Image.fromarray(rec[0].permute(1, 2, 0).to('cpu').numpy().astype(np.uint8))
            save_path = os.path.join(save_root_, 'gt_image_rec_rank{}.png'.format(local_rank))
            rec_im.save(save_path)

            if tuple(data_i['image'].shape[-2:]) != (256, 256) and 'image256' in data_i:
                rec = net.reconstruct(input=data_i['image256']) # B x C x H x W
                rec_im = Image.fromarray(rec[0].permute(1, 2, 0).to('cpu').numpy().astype(np.uint8))
                save_path = os.path.join(save_root_, 'gt_image256_rec_rank{}.png'.format(local_rank))
                rec_im.save(save_path)

        # generate samples in a batch manner
        count_per_cond_ = len(glob.glob(os.path.join(save_root_, generated_pattern)))
        while count_per_cond_ < count_per_cond:
            # prefer ratios that have not been used yet; once every ratio has
            # been used, draw from the full lists again
            fr_ = [r for r in filter_ratio if r not in chosed_fr]
            cr_ = [r for r in content_ratio if r not in chosed_cr]
            fr_ = filter_ratio if len(fr_) == 0 else fr_
            cr_ = content_ratio if len(cr_) == 0 else cr_

            fr = random.choice(fr_)
            cr = random.choice(cr_)

            if fr not in chosed_fr:
                chosed_fr.append(fr)
            if cr not in chosed_cr:
                chosed_cr.append(cr)

            start_batch = time.time()
            model_out = net.generate_content(
                batch=data_i,
                filter_ratio=fr,
                replicate=args.batch_size,
                content_ratio=cr,
                return_att_weight=return_att_weight,
            ) # B x C x H x W
            # save results
            content = model_out['content']
            content = content.permute(0, 2, 3, 1).to('cpu').numpy().astype(np.uint8)
            for b in range(content.shape[0]):
                cnt = count_per_cond_ + b
                save_base_name = 'rank_{}_{}_fr{}_cr{}'.format(local_rank, str(cnt).zfill(6), fr, cr)
                save_path = os.path.join(save_root_, save_base_name+'.png')
                im = Image.fromarray(content[b])
                im.save(save_path)
                print('Rank {}, Total time {}, batch time {:.2f}s, saved in {}'.format(local_rank, format_seconds(time.time()-start_gen), time.time()-start_batch, save_path))

                if return_att_weight:
                    att_save_dir = os.path.join(save_root_, save_base_name + '_attention')
                    os.makedirs(att_save_dir, exist_ok=True)
                    condition_attention = model_out['condition_attention'].to('cpu') # B x Lt x Ld
                    content_attention = model_out['content_attention'].to('cpu') # B x Lt x H x W
                    cond_att_save_path = os.path.join(att_save_dir, 'condition_attention')
                    cont_att_save_path = os.path.join(att_save_dir, 'content_attention')
                    torch.save(condition_attention, cond_att_save_path+'.pth')
                    torch.save(content_attention, cont_att_save_path+'.pth')

                    with open(cond_att_save_path+'.txt', 'w') as cond_att_f, \
                            open(cont_att_save_path+'.txt', 'w') as cont_att_f:
                        for cont_idx in range(content_attention.shape[1]):
                            cond_att_f.write(str(cont_idx)+'\n'+str(condition_attention[b, cont_idx, :])+'\n')
                            cont_att_f.write(str(cont_idx)+'\n'+str(content_attention[b, cont_idx])+'\n')

                            # save content attention as image, scaled so that
                            # its maximum maps to 255
                            cont_att_im = (content_attention[b, cont_idx]/content_attention[b, cont_idx].max() * 255).numpy().astype(np.uint8)
                            cont_att_im = Image.fromarray(cont_att_im)
                            cont_att_im.save(os.path.join(att_save_dir, '{}_content_attention.png'.format(cont_idx)))

            print('==> batch time {}s'.format(round(time.time() - start_batch, 1)))
            # recount from disk; the pattern matches the files of all ranks
            count_per_cond_ = len(glob.glob(os.path.join(save_root_, generated_pattern)))


def get_args(argv=None):
    '''
    Parse the command-line options of this inference script.

    argv: optional list of argument strings. When None (the default),
        argparse falls back to sys.argv[1:], so existing callers that use
        get_args() without arguments behave as before.

    Returns an argparse.Namespace holding every option below plus `cwd`,
    the absolute directory of this file. With --debug, `name` is forced to
    'debug' and `gpu` defaults to 0 when no GPU was given.
    '''
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

    parser.add_argument('--save_dir', type=str, default='RESULT',
                        help='directory to save results')

    parser.add_argument('--name', type=str, default='',
                        help='the name of this experiment, if not provided, set to '
                             'the name of config file')
    parser.add_argument('--func', type=str, default='inference_reconstruction',
                        help='the name of inference function')
    # args for ddp
    parser.add_argument('--num_node', type=int, default=1,
                        help='number of nodes for distributed training')
    parser.add_argument('--node_rank', type=int, default=0,
                        help='node rank for distributed training')
    parser.add_argument('--dist_url', type=str, default='auto',
                        help='url used to set up distributed training')
    parser.add_argument('--gpu', type=int, default=None,
                        help='GPU id to use. If given, only the specific gpu will be'
                        ' used, and ddp will be disabled')
    parser.add_argument('--batch_size', type=int, default=8,
                        help='batch size while inference')
    parser.add_argument('--data_type', type=str, default='val',
                        choices=['val', 'train'],
                        help='which data type to use (val or train)')
    parser.add_argument('--ema', action='store_true', default=False,
                        help='evaluate ema model')
    parser.add_argument('--ema_no_buffer', action='store_true', default=False,
                        help='update buffers in ema model')

    parser.add_argument('--debug', action='store_true',
                        help='set as debug mode')

    args = parser.parse_args(argv)
    args.cwd = os.path.abspath(os.path.dirname(__file__))

    # Debug mode overrides the experiment name and pins a GPU when the
    # user did not pick one explicitly.
    if args.debug:
        args.name = 'debug'
        if args.gpu is None:
            args.gpu = 0

    return args


# Dispatch table: maps the value of --func to the inference routine that
# the entry point hands over to `launch`.
inference_func_map = dict(
    inference_reconstruction=inference_reconstruction,
    inference_token=inference_token,
    inference_generate_sample_with_condition=inference_generate_sample_with_condition,
    update_ema_model_buffer=update_ema_model_buffer,
)


if __name__ == '__main__':
    # Script entry point: parse the options, derive the process layout
    # (GPUs per node, world size) and start the selected inference function
    # through `launch`.
    args = get_args()

    if args.gpu is not None:
        # A pinned GPU means a single process on that device.
        warnings.warn('You have chosen a specific GPU. This will completely disable ddp.')
        torch.cuda.set_device(args.gpu)
        args.ngpus_per_node = 1
        args.world_size = 1
    else:
        # Explicit check instead of `assert`, which is stripped under -O.
        if args.num_node < 1:
            raise ValueError('--num_node must be >= 1, got {}'.format(args.num_node))
        if args.num_node == 1:
            # Single-node runs always use the 'auto' dist_url.
            args.dist_url = "auto"
        args.ngpus_per_node = torch.cuda.device_count()
        args.world_size = args.ngpus_per_node * args.num_node

    args.distributed = args.world_size > 1

    # Fallback experiment name used when --name is not given.
    if args.name == '':
        args.name = 'OUTPUT/dalle_d24h16_PredCond_DalleTextEmbedding_gcc_lr3e-6none_Warmup4.5e-4_plateau_ema_g32/checkpoint/000016e.pth'

    # NOTE(review): 'caculate_flops_and_params' is not a key of
    # inference_func_map, so the lookup below raises KeyError for it.
    # Confirm whether the function should be registered in the map.
    if args.func == 'caculate_flops_and_params':
        args.gpu = 0
    launch(inference_func_map[args.func], args.ngpus_per_node, args.num_node, args.node_rank, args.dist_url, args=(args,))

